import fix for torch 2.9 #315
Conversation
Signed-off-by: Riyad Islam <[email protected]>
Walkthrough: Moved ONNX internal imports to runtime, environment-aware locations.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
Pre-merge checks: ✅ 3 passed
Signed-off-by: Riyad Islam <[email protected]>
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/export_onnx.py (1)
186-189: Bug: casting the input instead of the output. This block intends to cast the exported result back; currently it re-casts inputs. Cast `out` instead.
Apply:
```diff
-    if trt_high_precision_dtype != input_type:
-        inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[input_type])
+    if trt_high_precision_dtype != input_type:
+        out = g.op("Cast", out, to_i=onnx_dtype_map[input_type])
```
🧹 Nitpick comments (3)
modelopt/torch/quantization/export_onnx.py (3)
116-118: Avoid double-import of symbolic_helper; keep one alias and expose both names locally. Saves a redundant import while preserving existing references.
Apply:
```diff
-from torch.onnx import symbolic_helper
-from torch.onnx import symbolic_helper as sym_help
+from torch.onnx import symbolic_helper as sym_help
+# Keep both names available for callers in this module.
+symbolic_helper = sym_help
```
401-401: Remove redundant re-import inside function. `_attention_scale` and `_causal_attention_mask` are already imported at module scope.
Apply:
```diff
-    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
```
109-115: Add CI coverage across Torch versions for these private imports. Run exporter smoke tests on a matrix (e.g., 2.8.x and 2.9+) to catch future internal layout changes early; see the sketch below.
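A minimal pytest sketch of such a smoke test (the test function, toy model, and in-memory export target are illustrative assumptions, not code from this PR):
```python
# Illustrative smoke test to run under each Torch version in the CI matrix.
import io

import torch
import torch.nn as nn


def test_export_onnx_private_imports():
    # Importing the exporter module exercises the version-gated torch.onnx internals.
    import modelopt.torch.quantization.export_onnx  # noqa: F401

    # A trivial TorchScript-based export catches future internal layout changes early.
    model = nn.Linear(4, 4)
    torch.onnx.export(model, torch.randn(1, 4), io.BytesIO())
```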
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/torch/quantization/export_onnx.py (1 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
```python
from packaging.version import Version

if Version(torch.__version__) >= Version("2.9.0"):
    from torch.onnx._internal.torchscript_exporter import _type_utils
else:
    from torch.onnx import _type_utils
```
🛠️ Refactor suggestion
Prefer feature-detection over version checks for _type_utils (drop packaging dep).
Version-guarding is brittle (nightlies/alphas may differ). Try known locations in order and fall back, avoiding a hard dependency on packaging.
Apply:
```diff
-from packaging.version import Version
-
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter import _type_utils
-else:
-    from torch.onnx import _type_utils
+try:
+    # Newer PyTorch layouts first
+    from torch.onnx._internal.torchscript_exporter import _type_utils  # type: ignore[attr-defined]
+except Exception:
+    try:
+        from torch.onnx._internal.exporter import _type_utils  # type: ignore[attr-defined]
+    except Exception:
+        # Older layout
+        from torch.onnx import _type_utils
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
try:
    # Newer PyTorch layouts first
    from torch.onnx._internal.torchscript_exporter import _type_utils  # type: ignore[attr-defined]
except Exception:
    try:
        from torch.onnx._internal.exporter import _type_utils  # type: ignore[attr-defined]
    except Exception:
        # Older layout
        from torch.onnx import _type_utils
```
🤖 Prompt for AI Agents
In modelopt/torch/quantization/export_onnx.py around lines 109 to 115, replace
the packaging.version based import with feature-detection: attempt to import
_type_utils from torch.onnx._internal.torchscript_exporter in a try/except and
on ImportError fall back to importing _type_utils from torch.onnx; remove the
packaging.version import and any Version checks so the code tries known
locations in order and gracefully falls back for nightlies/alphas.
@kevalmorabia97 which approach is better?
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!
If torch is installed from source before the official 2.9 release (e.g. in DLFW PyTorch containers), it will have a version like `2.9.0a0...nv`, which a version check would treat as `<2.9.0` even though it already has the 2.9.0 changes. Let's change the version check to `>2.8` instead of `>=2.9`.
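A small sketch of why the prerelease tag trips the `>=2.9` check (the version string below is simplified; real source builds carry longer suffixes):
```python
from packaging.version import Version

nightly = Version("2.9.0a0")  # simplified stand-in for a pre-release source build

# The pre-release sorts *below* the final 2.9.0 release...
assert nightly < Version("2.9.0")  # so a ">= 2.9.0" gate takes the old-layout branch
# ...but it is still newer than any 2.8.x release.
assert nightly > Version("2.8")    # a "> 2.8" gate picks the new-layout branch correctly
```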
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
```
@@            Coverage Diff             @@
##             main     #315      +/-   ##
==========================================
- Coverage   73.85%   73.82%   -0.04%
==========================================
  Files         172      172
  Lines       17430    17438       +8
==========================================
  Hits        12873    12873
- Misses       4557     4565       +8
```
☔ View full report in Codecov by Sentry.
```python
from torch import nn
from torch.onnx._globals import GLOBALS

if Version(torch.__version__) >= Version("2.9.0"):
```
I'd prefer moving this import down into the function where it is used, since we are importing from `torch.onnx._internal`, which is a private module and may change again in the future, potentially breaking the import.
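A rough sketch of the in-function import being suggested (the helper and exporter function names here are hypothetical, not the exact code in this PR):
```python
import torch


def _resolve_type_utils():
    # Resolve the private torch.onnx module only when export actually runs,
    # so a future relayout of torch.onnx._internal cannot break module import.
    if hasattr(torch.onnx, "_type_utils"):
        from torch.onnx import _type_utils  # layout in torch < 2.9
    else:
        from torch.onnx._internal.torchscript_exporter import _type_utils  # torch >= 2.9
    return _type_utils


def some_symbolic_export(g, inputs):  # hypothetical exporter function
    _type_utils = _resolve_type_utils()
    ...  # use _type_utils here
```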
Force-pushed from 5f54471 to fda052f
Signed-off-by: Riyad Islam <[email protected]>
Force-pushed from fda052f to 009865b
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/export_onnx.py (1)
179-183: Bug: casting inputs instead of the output in export_int8. This prevents restoring the original dtype on the returned tensor.
Fix:
```diff
-    if trt_high_precision_dtype != input_type:
-        inputs = g.op("Cast", inputs, to_i=onnx_dtype_map[input_type])
+    if trt_high_precision_dtype != input_type:
+        out = g.op("Cast", out, to_i=onnx_dtype_map[input_type])
```
♻️ Duplicate comments (1)
modelopt/torch/quantization/export_onnx.py (1)
293-333: Use a try/except import chain (and switch to sym_help) for robustness across PyTorch layouts. hasattr gates can be brittle. Prefer attempting known locations in order and falling back. Also swap symbolic_helper → sym_help for consistency with the module-level alias.
Apply:
```diff
-    if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+    try:
+        # Newer layouts first
+        from torch.onnx._internal.torchscript_exporter import _type_utils  # type: ignore[attr-defined]
+    except ImportError:
+        try:
+            from torch.onnx._internal.exporter import _type_utils  # type: ignore[attr-defined]
+        except ImportError:
+            # Older layout
+            from torch.onnx import _type_utils
@@
-    assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
+    assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
@@
-    if symbolic_helper._is_none(scale):
+    if sym_help._is_none(scale):
@@
-    key_shape_builtin = symbolic_helper._get_tensor_rank(key)
+    key_shape_builtin = sym_help._get_tensor_rank(key)
@@
-    if symbolic_helper._is_none(attn_mask):
+    if sym_help._is_none(attn_mask):
```
🧹 Nitpick comments (3)
modelopt/torch/quantization/export_onnx.py (3)
109-110: Deduplicate import; standardize on a single alias. Both symbolic_helper and its alias are imported. Keep only the alias and use it consistently.
Apply:
```diff
-from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
```
Then replace symbolic_helper usages in this file with sym_help (see suggested edits below in the SDPA block).
399-399: Remove redundant in-function import (already imported at module scope). This function re-imports `_attention_scale` and `_causal_attention_mask` despite the top-level import. Safe to drop.
```diff
-    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
```
401-405: DRY the _type_utils fallback logic (reuse via a tiny helper). The same conditional import appears twice. Factor it once and reuse to avoid drift.
Apply in this function:
```diff
-    if hasattr(torch.onnx, "_type_utils"):
-        from torch.onnx import _type_utils
-    else:
-        from torch.onnx._internal.torchscript_exporter import _type_utils
+    _type_utils = _get_type_utils()
```
Add this helper once near the top of the file (after imports):
```diff
+from functools import lru_cache
+
+
+@lru_cache(None)
+def _get_type_utils():
+    try:
+        from torch.onnx._internal.torchscript_exporter import _type_utils  # type: ignore[attr-defined]
+        return _type_utils
+    except ImportError:
+        try:
+            from torch.onnx._internal.exporter import _type_utils  # type: ignore[attr-defined]
+            return _type_utils
+        except ImportError:
+            from torch.onnx import _type_utils
+            return _type_utils
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/torch/quantization/export_onnx.py (3 hunks)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
Actionable comments posted: 1
♻️ Duplicate comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
881-885: Previous feedback addressed: import moved into the function. Thanks for inlining the ONNX import into `forward`, aligning with the prior review note about private internals potentially changing.
🧹 Nitpick comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
881-885: Import only when needed to minimize Dynamo interactions. Move the import under the existing guard before using `GLOBALS` to avoid importing private ONNX internals on every forward and reduce chances of Dynamo graph breaks.
Example:
```python
# ... keep top of forward() unchanged
# Replace the later check with a lazy import:
if not is_torch_export_mode() and not torch._dynamo.is_compiling():
    try:
        from torch.onnx._globals import GLOBALS
    except Exception:
        try:
            from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
        except Exception:
            class GLOBALS:  # noqa: N801
                in_onnx_export = False
    if GLOBALS.in_onnx_export:
        self._check_onnx_readiness(inputs)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (1)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
881-885: Verify ONNX GLOBALS import across supported Torch versions. Confirm which Torch versions expose torch.onnx._globals vs torch.onnx._internal.torchscript_exporter._globals so the fallback doesn't mask required checks.
File: modelopt/torch/quantization/nn/modules/tensor_quantizer.py (lines 881-885)
Local run failed with ModuleNotFoundError (torch not installed) — run the following in an environment with torch installed (repeat for each supported Torch version / CI matrix) and paste the output:
```bash
python - <<'PY'
import importlib, importlib.util, sys
try:
    import torch
except Exception as e:
    print("ERROR: torch not installed:", e)
    sys.exit(2)
print("torch:", getattr(torch, "__version__", "<unknown>"))
cands = [
    "torch.onnx._globals",
    "torch.onnx._internal.torchscript_exporter._globals",
]
for m in cands:
    spec = importlib.util.find_spec(m)
    print(f"{m} -> find_spec: {bool(spec)}; origin: {spec.origin if spec else None}")
    try:
        mod = importlib.import_module(m)
        print("  import OK; file:", getattr(mod, "__file__", "<built-in>"))
        print("  GLOBALS present:", hasattr(mod, "GLOBALS"))
    except Exception as e:
        print("  import failed:", repr(e))
PY
```
Report: torch version and, for each candidate, whether find_spec/import succeeded, module file path, and whether GLOBALS exists.
```python
if hasattr(torch.onnx, "_globals"):
    from torch.onnx._globals import GLOBALS
else:
    from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
```
Make ONNX GLOBALS import robust (avoid hasattr; gracefully handle older Torch).
`hasattr(torch.onnx, "_globals")` can be false-positive/negative depending on lazy imports; also, neither path may exist on some Torch versions, which would currently raise ImportError on every forward even when ONNX isn't used. Use try/fallback and provide a safe sentinel.
Apply this diff:
```diff
-if hasattr(torch.onnx, "_globals"):
-    from torch.onnx._globals import GLOBALS
-else:
-    from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+try:
+    from torch.onnx._globals import GLOBALS
+except Exception:
+    try:
+        from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
+    except Exception:
+        # Fallback for Torch versions lacking both paths; skip ONNX readiness checks.
+        class GLOBALS:  # noqa: N801
+            in_onnx_export = False
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
try:
    from torch.onnx._globals import GLOBALS
except Exception:
    try:
        from torch.onnx._internal.torchscript_exporter._globals import GLOBALS
    except Exception:
        # Fallback for Torch versions lacking both paths; skip ONNX readiness checks.
        class GLOBALS:  # noqa: N801
            in_onnx_export = False
```
🤖 Prompt for AI Agents
In modelopt/torch/quantization/nn/modules/tensor_quantizer.py around lines 881
to 885, the current hasattr-based import for ONNX GLOBALS is brittle and may
raise ImportError on some Torch versions; replace it with a try/except
ImportError sequence: try to import GLOBALS from torch.onnx._globals, if that
fails try torch.onnx._internal.torchscript_exporter._globals, and if both fail
set GLOBALS to a safe sentinel (e.g., None or an object) so forward passes won’t
error when ONNX isn’t used.
Signed-off-by: Riyad Islam <[email protected]> Signed-off-by: Ye Yu <[email protected]>
Signed-off-by: Riyad Islam <[email protected]> Signed-off-by: Ye Yu <[email protected]>
What does this PR do?
Type of change: Bug fix
Overview: import fix for torch 2.9+ installed from source
Usage
# Add a code snippet demonstrating how to use this
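A hedged usage sketch (the config name, calibration loop, and file paths below are assumptions based on the public modelopt API, not code from this PR; the point is that ONNX export of a quantized model exercises the fixed import path):
```python
import torch
import torch.nn as nn
import modelopt.torch.quantization as mtq

# Toy model and calibration data; replace with your own.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4)).eval()
calib_data = [torch.randn(8, 16) for _ in range(4)]


def forward_loop(m):
    for batch in calib_data:
        m(batch)


# Quantize with a default INT8 config (config name assumed from the modelopt docs).
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# ONNX export goes through modelopt's symbolic functions, where the torch.onnx
# internals patched by this PR are resolved at runtime.
torch.onnx.export(model, torch.randn(1, 16), "quantized_model.onnx", opset_version=17)
```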
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit